%load_ext autoreload
%autoreload 2
import pickle
from tqdm import tqdm
import numpy as np
import scipy.interpolate
from sklearn.preprocessing import KBinsDiscretizer
import pandas as pd
import matplotlib
import matplotlib.pyplot as plt
import seaborn as sns
from modules.utils.general_utils.utilities import group_wise_binning
from plotly.subplots import make_subplots
import plotly.graph_objects as go
def interpolate_paths(z, x, y, c, rep_id, n_points=100):
    """Split a trajectory into runs of consecutive ``z`` values and resample each run.

    The input is cut wherever the step between neighbouring ``z`` values is not
    exactly 1, so no curve is drawn across a gap. Each run with two or more
    points is resampled onto ``n_points`` evenly spaced ``z`` values. ``x``,
    ``y`` and ``c`` are interpolated linearly (2 points), quadratically
    (3 points) or cubically (4+ points). Single-point runs are returned as-is.

    Args:
        z: 1-D array-like of positions along the path (here: session index).
        x, y, c: 1-D array-likes aligned element-wise with ``z``. Any
            array-like is accepted, including pandas Series. Values are
            converted to NumPy arrays so splitting is purely positional.
        rep_id: Identifier of the trajectory. Not used by the computation;
            kept for interface compatibility.
        n_points: Number of samples per interpolated run (default 100).

    Returns:
        List of ``(z, x, y, c)`` tuples of NumPy arrays, one per run, in the
        order the runs occur in the input.
    """
    # interp1d needs >= 3 points for quadratic and >= 4 for cubic, so the
    # interpolation order is matched to the run length.
    interp_kind = {2: "linear", 3: "quadratic", 4: "cubic"}

    z = np.asarray(z)
    x = np.asarray(x)
    y = np.asarray(y)
    c = np.asarray(c)

    # Positions where the step is not exactly 1 start a new run.
    break_points = np.where(np.diff(z) != 1)[0] + 1
    runs = zip(*(np.split(values, break_points) for values in (z, x, y, c)))

    paths = []
    for z_run, x_run, y_run, c_run in runs:
        if len(z_run) < 2:
            # Nothing to interpolate between; keep the isolated point unchanged.
            paths.append((z_run, x_run, y_run, c_run))
            continue
        kind = interp_kind.get(len(z_run), "cubic")
        z_new = np.linspace(np.min(z_run), np.max(z_run), n_points)
        x_new, y_new, c_new = (
            scipy.interpolate.interp1d(z_run, values, kind=kind)(z_new)
            for values in (x_run, y_run, c_run)
        )
        paths.append((z_new, x_new, y_new, c_new))
    return paths
# --- Load model outputs and rank users by prediction variability -------------
# Forward-slash paths are accepted by open()/read_csv on both Windows and
# POSIX, unlike backslash-only separators.
# NOTE(review): pickle.load can execute arbitrary code; only load containers
# produced by this project.
with open('results/saved_data_containers/melchior.pkl', 'rb') as container:
    DATA_CONTAINER = pickle.load(container)
predictions = DATA_CONTAINER['prediction_ds']['tar_activity']
contexts = DATA_CONTAINER['context']

# Keep entries 0..4 of each container.
predictions = [predictions[i] for i in range(5)]
contexts = [contexts[i] for i in range(5)]
# group_wise_binning is project-local; presumably it bins each prediction set
# into n_bins=100 levels per context group. hstack flattens the result into a
# single vector. TODO confirm semantics.
predictions = np.hstack(group_wise_binning(predictions, n_bins=100, grouper=contexts))

df = pd.read_csv('results/saved_dim_reduction/melchior_eng_emb_temporal.csv')
# NOTE(review): assumes the binned predictions are row-aligned with the CSV.
df['Future Session Activity'] = predictions
df = df[df['context'] == 6]

# Keep users whose highest session index + 1 equals 4 (i.e. max session == 3).
session_counts = df.groupby('user_id')['session'].max() + 1
users = session_counts[session_counts == 4].index.values
df = df[df['user_id'].isin(users)]

# Per-user variance of the predicted activity, discretized into 9 ordinal
# bins (sklearn's default binning strategy) stored in the 'rank' column.
discretizer = KBinsDiscretizer(n_bins=9, encode='ordinal')
variability_rank = (
    df.groupby('user_id')['Future Session Activity']
    .agg(lambda x: np.var(x.values))
    .reset_index()
)
variability_rank['rank'] = discretizer.fit_transform(
    variability_rank['Future Session Activity'].values.reshape((-1, 1))
)
sns.histplot(variability_rank['rank'].values)
# --- 3-D trajectory figure ---------------------------------------------------
# Each panel shows users from one variability rank. Each user is drawn as a
# line through (UMAP 1, session, UMAP 2), coloured by predicted future-session
# activity.
zoom = 3.5  # multiplier on the camera eye position shared by all scenes
MAX_USERS_PER_PANEL = 700  # cap on the number of users drawn per panel
SAMPLE_SEED = 0  # fixed seed so the sampled users are identical on every run

fig = make_subplots(
    rows=1,
    cols=3,
    specs=[
        [{'type': 'scatter3d'}, {'type': 'scatter3d'}, {'type': 'scatter3d'}]
    ],
    subplot_titles=(
        'Low Variability',
        'Medium Variability',
        'High Variability'
    ),
    horizontal_spacing=0.01,
    vertical_spacing=0.05,
    shared_xaxes=True,
    shared_yaxes=True
)

# (variability rank, subplot column). Ranks 0/4/8 presumably correspond to the
# low/medium/high titles above.
panels = [(0, 1), (4, 2), (8, 3)]

rng = np.random.default_rng(SAMPLE_SEED)
# Group once so each user's rows are fetched directly instead of re-scanning
# the whole frame with boolean masks for every user.
rows_by_user = df.groupby('user_id')

for rank, col in panels:
    unique_ids = variability_rank.loc[variability_rank['rank'] == rank, 'user_id'].values
    print(len(unique_ids))
    unique_ids = rng.choice(
        unique_ids,
        min(len(unique_ids), MAX_USERS_PER_PANEL),
        replace=False,
    )
    for unique_id in unique_ids:
        user_rows = rows_by_user.get_group(unique_id)
        paths = interpolate_paths(
            user_rows['session'].values,
            user_rows['UMAP_1'].values,
            user_rows['UMAP_2'].values,
            user_rows['Future Session Activity'].values,
            unique_id,
        )
        for t, umap_1, umap_2, activity in paths:
            # Session index goes on the scene's y axis (stretched by
            # aspectratio y=3); UMAP 2 is the vertical z axis.
            fig.add_trace(
                go.Scatter3d(
                    x=umap_1,
                    y=t,
                    z=umap_2,
                    mode='lines',
                    line=dict(
                        color=activity,
                        cmin=0,
                        cmid=50,
                        cmax=100,
                        cauto=False,
                        colorscale='RdBu',
                        colorbar=dict(),
                        width=0.6,
                    ),
                    opacity=0.6,
                ),
                row=1,
                col=col,
            )

fig.update_layout(
    width=1050,
    height=500,
    margin=dict(r=1, l=1),
    showlegend=False,
    autosize=False,
    template="plotly_white",
)
# A single update_scenes call gives all three scenes the same aspect, camera
# and axis setup.
fig.update_scenes(
    aspectmode='manual',
    aspectratio=dict(x=1, y=3, z=1),
    camera=dict(
        up=dict(x=0, y=0, z=1),
        center=dict(x=0.5, y=0.5, z=0),
        eye=dict(x=0.5 * zoom, y=0.8 * zoom, z=0.75 * zoom),
    ),
    xaxis_title_text='UMAP 1',
    zaxis_title_text='UMAP 2',
    yaxis_title_text=r"$t$",
    # Tick positions 0-3 are labelled 1-4 (session shown 1-based).
    yaxis=dict(
        tickmode='array',
        tickvals=[0, 1, 2, 3],
        ticktext=[1, 2, 3, 4],
    ),
)
fig.show()